PyTorch Ray Double Free 분석
분산 PyTorch 및 Ray Tune 환경에서의 메모리 충돌 분석: Double Free에서 워커 OOM까지
TL;DR
free(): double free or corruption detected in tcache 2— 이 에러는 단순한 코드 버그가 아니다. **컨테이너의 부족한 ****/dev/shm**이 Ray와 PyTorch의 메모리 관리를 불안정하게 만들고, PyTorch 연산 그래프의 암묵적 누수가 겹치면서 발생하는 시스템적 실패다. 환경 설정(--shm-size)을 먼저 바로잡고, 코드에서는 내부 상태의detach()+loss.item()패턴으로 그래프 누수를 차단하면 해결된다.
이 에러를 본 적 있는가?
Ray Tune으로 하이퍼파라미터 탐색을 돌리던 중, 아무런 예고 없이 프로세스가 죽으면서 터미널에 이런 메시지가 뜬다.
free(): double free detected in tcache 2
Aborted (core dumped)스택 트레이스도 없고, Python 예외도 아니다. 그냥 죽었다. 혹은 이런 형태로 나타나기도 한다.
ray.exceptions.RayActorError: The actor died unexpectedly before finishing this task.
Worker exit type: UNEXPECTED_SYSTEM_EXIT
connection error code 2. End of file그리고 로그를 조금 위로 올려보면 이런 경고가 조용히 찍혀 있다.
WARNING: The object store is using /tmp instead of /dev/shm
because /dev/shm has only 67108864 bytes available.이 글은 이 메시지들 사이의 인과 관계를 추적한 분석이다. OS/컨테이너 환경에서 출발하여 Ray, PyTorch 스택을 거쳐 실제 코드 수정까지, 상향식으로 전개한다.
1. 충돌의 해부: 에러 메시지가 말해주는 것
1.1. 힙 손상(Heap Corruption)의 정체
double free or corruption은 C 표준 라이브러리의 메모리 할당자(malloc / jemalloc / tcmalloc)가 발생시키는 에러다. 프로그램이 이미 해제된 포인터를 다시 free() 하려 했거나, 메모리 블록의 관리 메타데이터가 손상되었다는 뜻이다 (참고: TensorFlow #6968).
이 에러의 디버깅이 악명 높은 이유는 실제 손상 시점과 감지 시점이 다르기 때문이다. 메모리 블록 A가 잘못 해제되는 일은 수천 줄 전에 일어났을 수 있지만, 할당자가 이를 감지하는 것은 한참 뒤 블록 A의 메타데이터를 다시 참조할 때다.
1.2. RayActorError와 connection error code 2의 실체
RayActorError는 HPO Trial을 담당하는 Ray 워커 프로세스 내부에서 치명적 에러가 발생했다는 뜻이다. UNEXPECTED_SYSTEM_EXIT은 Ray의 통제 밖에서 프로세스가 강제 종료되었음을 의미한다.
connection error code 2. End of file은 네트워크 문제처럼 보이지만, 실제로는 워커 프로세스가 OOM Killer에 의해 SIGKILL로 강제 종료되면서 Raylet과의 연결이 끊긴 것이다. 인과 관계를 정리하면 이렇다:
- 학습 루프에서 해제되지 않은 메모리가 누적된다
- 시스템 메모리 압박이 발생한다
- Linux OOM Killer가 SIGKILL을 보낸다
- 프로세스가 즉시 종료되면서 Raylet 연결이 끊긴다
- Ray가
End of file→RayActorError를 전파한다 traceback이loss.backward()에서 멈춰 있다면, 이는 그래디언트 계산이 메모리 급증 지점이라는 강력한 신호다.
1.3. Python인데 왜 C 레벨 에러가 나올까
NumPy 배열이나 PyTorch 텐서 같은 고수준 Python 객체는, 실제로는 **C/C++ 코드가 할당하고 관리하는 거대한 메모리 블록을 감싸는 래퍼(wrapper)**에 불과하다. Ray, PyTorch, 시스템 라이브러리 등 여러 네이티브 구성 요소가 동시에 메모리를 다루면서, 특정 메모리 조각의 "소유권"과 해제 책임에 대해 서로 다른 가정을 하게 되는 것이 double free의 주된 경로다.
흥미로운 증거가 하나 있다. 메모리 할당자를 tcmalloc으로 교체(LD_PRELOAD)하면 문제가 "해결"되는 사례가 보고되어 있다 (참고: TensorFlow #6968). 하지만 이것은 근본 원인을 해결한 것이 아니다. 할당자의 내부 동작이 바뀌면서, 진행 중인 특정 유형의 손상에 "우연히" 덜 취약해진 것뿐이다. 견고한 애플리케이션은 어떤 할당자를 쓰든 올바르게 동작해야 한다.
1.4. 증상과 근본 원인을 분리해서 생각하기
디버깅의 핵심 사고방식: 충돌 로그는 증상이고, 근본 원인은 메모리 생명주기 관리의 논리 오류다. 이 글의 목표는 고수준 로직에서 저수준 충돌까지의 인과 관계 사슬을 완성하는 것이다.
2. 환경이 먼저: /dev/shm과 컨테이너의 함정
2.1. /dev/shm이란 무엇이고, 왜 중요할까?
/dev/shm은 전적으로 RAM에 상주하는 tmpfs 마운트다. 동일 머신의 프로세스들이 커널을 거치지 않고 대량의 데이터를 공유하는 가장 빠른 방법, 즉 POSIX 공유 메모리의 기반이다.
2.2. Ray와 PyTorch가 /dev/shm에 의존하는 방식
Ray의 핵심인 **객체 저장소(Object Store)**는 /dev/shm 내에 거대한 메모리 매핑 파일을 생성하여 동일 노드의 워커들이 객체를 zero-copy로 읽을 수 있게 한다. Ray는 기본적으로 가용 메모리의 30%를 객체 저장소에 할당한다 (참고: Ray Memory Management).
PyTorch의 torch.multiprocessing도 DataLoader의 num_workers > 0 설정 시 프로세스 간 텐서 전달을 위해 shm_open을 통해 /dev/shm에 파일을 생성한다 (참고: PyTorch Multiprocessing). PyTorch의 공유 메모리 전략은 두 가지가 있다:
file_descriptor(Linux 기본값):shm_open의 파일 디스크립터를 UNIX 소켓으로 전달. 효율적이지만 동시 열린 fd 수(ulimit -n)에 제약을 받는다.file_system(macOS/Windows 기본값): 파일 이름으로 공유 메모리를 식별. fd 제한에 강하지만, 프로세스 충돌 시 파일이 정리되지 않을 수 있다. PyTorch는 이를 위해torch_shm_manager데몬을 생성한다. 즉, 두 프레임워크가 동일한 제한된 리소스를 두고 경쟁하는 구조다.
2.3. 경고 메시지가 말해주는 것
WARNING: The object store is using /tmp instead of /dev/shm
because /dev/shm has only 67108864 bytes available.이 메시지는 시스템이 고성능 RAM 기반 IPC에서 디스크 기반 IPC(/tmp)로 강제 전환되었음을 의미한다. 결과는 두 가지다:
- 성능 저하: 객체 저장/로딩이 현저히 느려진다.
- 안정성 저하: 메모리 관리 메커니즘 자체가 바뀐다. 순수한 메모리 연산 대신 파일 I/O, 파일 핸들, 디스크 공간을 다루게 되면서 새로운 실패 모드가 열린다. 프로세스가 비정상 종료될 때 임시 파일이 정리되지 않아 후속 작업이 더 빠르게 실패하는 자원 고갈의 악순환이 시작될 수 있다.
이 문제의 주범은 Docker와 WSL이 기본으로 할당하는
/dev/shm크기가 64MB에 불과하다는 것이다 (참고: Stack Overflow). Ray나 PyTorch 같은 데이터 집약적 애플리케이션에는 턱없이 부족하다.
2.4. 진단과 수정
진단:
df -h /dev/shm수정 (Docker):
docker run --shm-size=4g ...수정 (docker-compose):
services:
training:
shm_size: '4g'수정 (Kubernetes):
volumes:
- name: dshm
emptyDir:
medium: Memory
sizeLimit: "4Gi"Ray 공식 문서는 컨테이너 가용 RAM의 30% 이상을 /dev/shm에 할당할 것을 권장한다[4]. 환경을 안정화하면, 예측 불가능한 시스템 문제에서 예측 가능한 애플리케이션 수준의 버그로 문제의 범위를 좁힐 수 있다.
3. Ray의 메모리 관리와 OOM 방지 메커니즘
3.1. Ray의 메모리 아키텍처
Ray의 메모리는 크게 세 영역으로 나뉜다:
- 워커 힙(Worker Heap): 각 워커 프로세스가 사용하는 일반 힙 메모리 (
RSS - SHR로 측정) - 객체 저장소 메모리(Object Store Memory):
ray.put()이나 태스크 반환값이 저장되는 공간 - 객체 저장소 공유 메모리: 위 객체들이 다른 프로세스와 공유되는 부분
Ray는
ObjectRef에 대해 분산 참조 카운팅을 사용한다. 클러스터 어디에서든 참조가 하나라도 남아 있으면 객체는 저장소에 고정(pin)된다. 참조를 명시적으로 해제하지 않으면 이것이 "누수"의 원인이 된다.
3.2. Ray Tune Trial의 생명주기와 "메모리 크립"
Ray Tune은 Trial마다 리소스를 할당하고, Trainable의 setup() -> step() -> cleanup()을 실행한 뒤 해제한다. 여기서 흔히 겪는 **"메모리 크립(memory creep)"**은 Trial이 종료된 후에도 GPU 메모리나 워커 힙이 완전히 해제되지 않는 현상이다.
이 패턴은 "두 번째 Trial마다 실패"하는 증상으로 나타나기도 한다 (참고: Stack Overflow). 첫 번째 Trial의 잔여 메모리 위에 두 번째 Trial이 쌓이면서 임계점을 넘는 것이다.
cleanup 메서드에서 del, gc.collect(), torch.cuda.empty_cache()를 호출하는 것은 모범 사례다. 하지만 Trial 내에서 생성된 객체에 대한 참조가 외부로 "탈출"한 경우 — 중앙 로거, Tuner 반환값, 백그라운드 스레드 등 — gc.collect()는 해당 메모리를 회수할 수 없다. cleanup은 최종 방어선이지 1차 방어선이 아니다.
3.3. OOM 모니터: 안전장치와 그 한계
Ray에는 노드 수준의 메모리 사용량을 추적하는 OOM 방지 모니터가 내장되어 있다(Linux 전용, cgroup v1/v2 지원) (참고: Ray OOM Prevention).
| 환경 변수 | 기본값 | 설명 |
|---|---|---|
RAY_memory_usage_threshold | 0.95 | 시스템 메모리 사용률이 이 값을 초과하면 워커를 선제 종료 |
RAY_memory_monitor_refresh_ms | 250 | 메모리 점검 주기 (ms). 0으로 설정하면 비활성화 |
사용량이 임계값을 초과하면, OS의 OOM kill 전에 메모리를 가장 많이 사용하는 워커부터 선제적으로 종료시킨다. 이때 ray.exceptions.OutOfMemoryError라는 명확한 에러 메시지가 출력된다.
그런데 실제 사례에서 double free나 connection error code 2가 뜨고 OutOfMemoryError는 뜨지 않았다면, 메모리 증가가 OOM 모니터의 250ms 폴링 간격보다 빠르게 진행되었거나, 손상이 힙 크기가 아닌 메모리 할당 메타데이터 수준에서 발생했을 가능성이 높다. 이는 Ray의 안전장치조차 제어된 종료를 수행할 수 없을 만큼 빠르게 진행되는, 심각한 C++ 레벨 메모리 버그를 시사한다.
4. PyTorch 연산 그래프의 함정
4.1. Autograd 그래프가 메모리를 잡아먹는 메커니즘
requires_grad=True인 텐서에 연산이 수행될 때마다 PyTorch는 grad_fn을 첨부하여 리프 노드까지의 **연산 연결 리스트(computational graph)**를 생성한다. 이 그래프는 backward() 호출 시 그래디언트를 계산하는 데 사용되며, 이 그래프 자체가 메모리를 점유하는 주체다.
이상적인 학습 루프에서 그래프의 생명주기는 이렇다:
- 순전파: 그래프가 생성된다
loss.backward(): 그래프를 순회하며 그래디언트를 계산한다optimizer.step(): 파라미터를 업데이트한다- 그래프가 해제된다
- 다음 반복에서 새 그래프가 생성된다 문제는 3~4단계에서 그래프가 제대로 해제되지 않을 때 발생한다.
4.2. 전형적인 누수 시나리오: 손실 누적
# 잘못된 코드 -- 매 반복마다 연산 그래프가 확장된다
total_loss = 0
for batch in dataloader:
loss = model(batch)
total_loss += loss # loss의 grad_fn이 total_loss에 연결된다
loss.backward()loss가 grad_fn을 가진 상태에서 total_loss += loss를 수행하면, 매 반복마다 연산 그래프가 이전 반복의 그래프와 연결되면서 확장된다. 10번째 반복의 total_loss는 1~10번째 반복의 전체 연산 기록을 메모리에 들고 있게 된다.
# 올바른 코드 -- .item()으로 스칼라를 추출하여 그래프 연결을 끊는다
total_loss = 0
for batch in dataloader:
loss = model(batch)
loss.backward()
total_loss += loss.item() # Python float로 변환, 그래프와 무관.item()은 텐서에서 Python 스칼라 값만 추출하므로 연산 그래프와의 연결이 완전히 끊어진다.
4.3. 그래프 관리 도구: detach(), clone(), no_grad() 정확히 구분하기
이 세 가지는 자주 혼동되지만, 각각의 역할이 명확히 다르다.
with torch.no_grad(): -- 스코프 내 모든 연산의 그래디언트 추적을 비활성화하는 컨텍스트 관리자. 텐서 자체를 변경하지 않지만, 새로운 연산이 그래프에 추가되지 않는다. 추론/검증 루프에 적합하다.
with torch.no_grad():
predictions = model(test_input) # grad_fn이 생성되지 않는다.detach() -- 텐서 메서드. 원본과 동일한 데이터를 공유하지만 grad_fn이 None인 새 텐서를 반환한다. 연산 그래프에서 텐서를 "잘라낸다".
feature = encoder(x).detach() # feature는 encoder의 그래프와 분리됨
output = decoder(feature) # decoder의 그래프만 생성됨.clone() -- 텐서 메서드. 데이터의 복사본과 연산 기록의 복사본을 가진 새 텐서를 반환한다. 복제된 텐서로의 역전파는 원본으로도 전파된다.
loss_clone = loss.clone()
loss_clone.backward() # loss 원본으로도 그래디언트가 전파된다| 연산 | 데이터 공유 | 그래프 기록 공유 | requires_grad | 원본으로 그래디언트 전파 | 대표 사용 사례 |
|---|---|---|---|---|---|
x.detach() | 예 | 아니요 | 원본과 동일 | 아니요 | 그래프에서 분리하여 메트릭 계산 등에 사용 |
x.clone() | 아니요 | 예 | 원본과 동일 | 예 | 그래프를 분기하되 원본으로 그래디언트 전파가 필요할 때 |
x.detach().clone() | 아니요 | 아니요 | 원본과 동일 | 아니요 | 연산 기록 없이 데이터만 안전하게 복사 (상태 전이) |
copy.deepcopy(x) | 아니요 | 아니요 | 원본과 동일 | 아니요 | 텐서와 연산 기록까지 완전히 독립적인 복사본 생성 |
with torch.no_grad(): | -- | -- | -- | 아니요 | 추론/검증에서 블록 내 모든 그래디언트 추적 중단 |
5. 실전 사례: FirstOrderPlant의 은닉된 그래프 누수
5.1. 겉보기에는 올바른 코드
PID 제어기(PIDNet)를 학습시키는 루프에서, current_val은 .detach()로 적절히 처리되어 있었다.
current_val = next_val.detach()이 코드 자체는 올바르다. 그런데 왜 메모리가 폭증할까?
5.2. 진짜 범인: Plant 객체의 내부 상태
문제는 FirstOrderPlant 클래스 내부에 있었다.
class FirstOrderPlant:
def __init__(self, ku, tau):
self.ku = ku
self.tau = tau
self.y = torch.tensor(0.0)
def step(self, u, dt):
next_y = self.y + (self.ku * u - self.y) / self.tau * dt
self.y = next_y # 여기가 누수 지점
return self.y인과 관계를 추적하면:
plant.step(u, dt)호출 -- 입력u는 PIDNet의 출력이므로 연산 기록을 가진 텐서다next_y는u로부터 유도되므로,u의 연산 그래프와 연결된다self.y = next_y-- 내부 상태가 그래프의 일부가 된다- 다음 step에서
self.y를 기반으로 새로운next_y를 계산한다 - 이전 step의 그래프가 새 step의 그래프에 연결된다
- 이 연결이 반복마다 누적되어 전체 학습 이력이 하나의 거대한 그래프로 연결된다
- 마지막
loss.backward()시 이 전체 그래프를 순회하면서 메모리가 폭발한다 주석에 "state.detach() is handled internally"라고 적혀 있더라도, 실제로.detach()가 빠져 있으면 이 누수가 발생한다. 이것이 이 사례의 핵심이다: 그래프 누수 경로가 사용자 코드가 아닌 시뮬레이션 객체 내부에 은닉되어 있었다.
5.3. 수정된 코드
class FirstOrderPlant:
def __init__(self, ku, tau):
self.ku = ku
self.tau = tau
self.y = torch.tensor(0.0)
def reset(self):
self.y = torch.tensor(0.0)
def step(self, u, dt):
next_y = self.y + (self.ku * u - self.y) / self.tau * dt
self.y = next_y.detach() # 그래프 분리: 내부 상태는 그래프를 들고 있지 않는다
return next_y # 반환값은 그래프를 유지하여 backward()에 사용핵심은 반환값과 내부 상태의 역할을 분리하는 것이다:
return next_y-- 현재 step의 그래프를 유지하여loss.backward()에 사용된다self.y = next_y.detach()-- 다음 step의 초기값으로만 쓰이므로 그래프 연결이 필요 없다
5.4. 누수 없는 전체 학습 루프
def pid_training_loop(config, steps: int = 20):
lr = config["lr"]
kp, ki, kd = config["kp"], config["ki"], config["kd"]
pid_net = PIDNet(kp, ki, kd)
plant = FirstOrderPlant(ku=2.0, tau=5.0)
optimizer = torch.optim.Adam(pid_net.parameters(), lr=lr)
total_loss = 0.0
current_val = torch.tensor(0.0)
target_val = torch.tensor(10.0)
dt = torch.tensor(0.1)
pid_net.reset()
plant.reset()
for _ in range(steps):
u = pid_net(current_val, target_val, dt)
next_val = plant.step(u, dt)
loss = (target_val - next_val) ** 2
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() # 스칼라만 누적
current_val = next_val.detach() # 다음 반복을 위한 깨끗한 텐서
return total_loss / steps5.5. 누수 여부를 정량적으로 확인하는 방법
연산 그래프 누수가 의심될 때, psutil로 프로세스 RSS를 추적하면 명확히 확인할 수 있다.
import psutil, os
def pid_training_loop(config, steps=20):
process = psutil.Process(os.getpid())
print(f"Trial 시작. 초기 RAM: {process.memory_info().rss / 1024**2:.2f} MB")
for i in range(steps):
# ... 학습 로직 ...
if (i + 1) % 5 == 0:
print(f" step {i+1}, RAM: {process.memory_info().rss / 1024**2:.2f} MB")step이 진행될수록 RAM이 꾸준히 증가한다면 메모리 누수가 확실하다. GPU 메모리는 torch.cuda.memory_allocated()로 같은 방식으로 추적할 수 있다.
6. 일반적인 PyTorch 메모리 누수 패턴
위 사례는 "커스텀 모듈 내부 상태" 패턴의 한 예다. 실무에서 자주 마주치는 그래프 누수 패턴을 정리하면 다음과 같다.
| 패턴 | 문제의 코드 | 근본 원인 | 올바른 코드 |
|---|---|---|---|
| 손실 누적 | total_loss += loss | total_loss가 loss의 그래프를 계속 참조 | total_loss += loss.item() |
| 리스트에 텐서 저장 | history.append(output) | 리스트가 output의 그래프를 참조 | history.append(output.detach()) |
| 커스텀 모듈 내부 상태 | self.state = new_state | 상태 변수가 그래프 연속성의 원인 | self.state = new_state.detach() |
| 로깅용 딕셔너리 | log["out"] = model(x) | 딕셔너리가 그래프 연결된 텐서를 보유 | log["out"] = model(x).detach().cpu() |
| step 간 상태 전이 | self.current_val = next_val | 이전 step의 그래프가 다음 step으로 전파 | self.current_val = next_val.detach().clone() |
공통 원칙: 학습 루프 외부로 나가는 텐서는 반드시 그래프 연결을 끊어야 한다. .item(), .detach(), .detach().clone() 중 상황에 맞는 것을 선택한다.
7. Ray 구성 전략: 디버깅에서 분산 확장까지
7.1. 디버깅: Trial 내부 문제 추적하기
Ray 환경에서 각 Trial은 별도 프로세스로 실행되므로, IDE 디버거(breakpoint, PDB 등)가 기본적으로 작동하지 않는다.
방법 1: Ray 없이 직접 실행
가장 단순하고 효과적인 디버깅 방법은 Ray를 거치지 않고 학습 함수를 직접 호출하는 것이다.
# 디버깅 시
result = pid_training_loop({"lr": 0.01, "kp": 1.0, "ki": 0.1, "kd": 0.05})이렇게 하면 일반적인 Python 디버깅 도구를 모두 사용할 수 있다.
**방법 2: ****FailureConfig**로 에러 즉시 노출
from ray.tune import RunConfig, FailureConfig
tuner = tune.Tuner(
trainable,
run_config=RunConfig(
failure_config=FailureConfig(fail_fast="raise")
),
# ...
)fail_fast="raise" 설정은 Trial에서 예외가 발생하면 즉시 상위로 전파하여, 에러가 조용히 묻히는 것을 방지한다.
방법 3: Ray Distributed Debugger (VS Code)
Ray 2.39+에서는 VS Code 확장(ray-distributed-debugger)을 통해 분산 환경에서도 breakpoint()를 사용할 수 있다. pip install "ray[default]" debugpy를 설치하고, 원격 태스크 내에서 breakpoint()를 호출하면 VS Code가 해당 프로세스에 어태치된다 (참고: Ray Distributed Debugger).
참고:
ray.init(local_mode=True)는 과거에 디버깅용으로 권장되었으나, 현재 deprecated 상태이며 향후 제거될 예정이다. 실험적 기능으로 유지보수되지 않으므로 새 프로젝트에서는 사용을 피하는 것이 좋다 (참고: ray.init() API).
7.2. 병렬 HPO 실행
디버깅이 완료되면 전체 CPU 코어를 활용한 병렬 실행으로 전환한다.
import os
from ray import tune
from ray.tune.search.optuna import OptunaSearch
def run_hpo():
ray.shutdown()
num_cpus = os.cpu_count()
ray.init(num_cpus=num_cpus, ignore_reinit_error=True)
search_space = {
"lr": tune.loguniform(1e-4, 1e-1),
"kp": tune.uniform(0.1, 5.0),
"ki": tune.uniform(0.01, 2.0),
"kd": tune.uniform(0.01, 2.0),
}
# pid_training_loop를 Trainable로 감싸는 래퍼
def objective(config):
loss = pid_training_loop(config)
return {"loss": loss}
trainable = tune.with_resources(objective, {"cpu": 1})
tuner = tune.Tuner(
trainable,
param_space=search_space,
tune_config=tune.TuneConfig(
search_alg=OptunaSearch(metric="loss", mode="min"),
num_samples=50,
),
)
results = tuner.fit()tune.with_resources()로 Trial당 리소스를 지정하면 동시 실행 Trial 수가 floor(총 CPU / Trial당 CPU)로 결정된다. CPU 8코어에 Trial당 1코어면 8개 Trial이 동시에 실행된다.
7.3. Ray 실행 모드 비교
| 구성 요소 | 디버깅 (직접 실행) | 단일 워커 모드 | 병렬 HPO 모드 |
|---|---|---|---|
| 초기화 | Ray 불필요 | ray.init(num_cpus=1) | ray.init(num_cpus=os.cpu_count()) |
| 실행 방식 | 단일 프로세스, 직접 호출 | 다중 프로세스, 순차 실행 | 다중 프로세스, 병렬 실행 |
| 디버깅 편의성 | 최상 -- 모든 디버거 지원 | 낮음 | 낮음 -- 분산 디버거 필요 |
| 추천 시나리오 | 메모리 누수, 로직 오류 추적 | 간단 로그 확인 | 대규모 탐색 시 처리량 극대화 |
8. 디버깅 전략과 환경 설정 체크리스트
8.1. 상향식(Bottom-Up) 디버깅
직관과 달리, Python 코드가 아니라 환경 검증부터 시작하는 것이 가장 효율적이다. 문제는 종종 스택의 가장 낮은 계층에 있고, 그 계층일수록 수정도 간단하다.
| 단계 | 항목 | 검사 방법 및 조치 |
|---|---|---|
| 1단계: 환경 | /dev/shm 크기 | df -h /dev/shm으로 확인. 부족하면 --shm-size로 증설 |
| 컨테이너 메모리 제한 | docker stats로 사용량/한계 확인 | |
| 시스템 메모리 | free -h, htop으로 전체 메모리 및 스왑 확인 | |
| 2단계: 격리 | Ray 없이 재현 | 학습 함수를 직접 호출하여 재현 여부 확인 |
| 최소 Trainable | time.sleep만 있는 Trainable로 재현 여부 확인 | |
| 최소 HPO 실행 | num_samples=1로 재현 여부 확인 | |
| 3단계: 모니터링 | Ray 메모리 | ray memory로 누수된 ObjectRef 확인 |
| Ray 대시보드 | http://127.0.0.1:8265에서 메모리 추세 확인 | |
| PyTorch 메모리 | psutil로 RSS 추적, torch.cuda.memory_allocated() 로깅 | |
| 4단계: 코드 | 텐서 누적 | 모든 루프의 텐서 누적이 .detach() 또는 .item() 사용하는지 검토 |
| 내부 상태 | 커스텀 모듈의 상태 업데이트에 .detach() 적용 여부 검토 | |
no_grad 사용 | 비훈련 코드(검증, 평가)를 with torch.no_grad():로 감쌌는지 확인 |
8.2. 실행 안정성을 위한 환경 설정
| 설정 항목 | 권장 설정 | 설명 |
|---|---|---|
| 프로세스 생성 방식 | spawn | CUDA 텐서 공유 시 fork는 사용 불가. torch.multiprocessing.set_start_method("spawn", force=True) 사용 |
| 스레드 수 제한 | MKL_NUM_THREADS=1, OMP_NUM_THREADS=1 | 각 Trial이 내부에서 과도한 스레드를 생성하지 않도록 제한 |
| CUDA 디버깅 | CUDA_LAUNCH_BLOCKING=1 | CUDA 커널 실행을 동기 모드로 전환하여 에러 발생 지점을 정확히 추적. 성능이 크게 저하되므로 디버깅 시에만 사용 |
spawn 설정은 실수로 빠뜨리기 쉬우므로, 진입점 상단에 강제로 넣는 것이 좋다.
import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)8.3. 고급 고려사항: 대체 메모리 할당자
모든 방법을 시도한 뒤에도 문제가 남으면, LD_PRELOAD로 tcmalloc 같은 다른 할당자를 실험해볼 수 있다.
LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 python train.py단, 이것은 근본적 수정이 아니라 고급 진단 도구 또는 임시 운영 해결책으로만 간주해야 한다. 할당자를 바꿔서 안정화되었다면, 그것은 "원래 할당자에서 드러나는 메모리 관리 버그가 아직 코드에 남아 있다"는 증거다.
마치며
double free와 RayActorError는 서로 다른 증상이었지만, 인과 관계 사슬의 뿌리는 같았다. **잘못 구성된 컨테이너 환경(/dev/shm)**이 시스템을 불안정하게 만들고, PyTorch 연산 그래프의 암묵적 누수 -- 특히 시뮬레이션 객체 내부에 은닉된 그래프 참조 -- 가 메모리를 폭증시켜 최종 충돌로 이어졌다.
이 경험에서 얻은 교훈은 명확하다. 현대의 ML 시스템을 디버깅하려면 모델 코드만 볼 수 없다. 컨테이너 설정, 오케스트레이션 프레임워크, 딥러닝 라이브러리의 내부 구조까지 아우르는 시스템적 사고가 필수적이다.
최종 해결책을 요약하면:
- 환경:
/dev/shm을 가용 RAM의 30% 이상으로 설정 (--shm-size=4g) - 코드: 커스텀 모듈의 내부 상태에
.detach()적용, 손실 누적에.item()사용, step 간 상태 전이에detach().clone()패턴 - 정리:
cleanup메서드에서del->gc.collect()->torch.cuda.empty_cache()순서로 명시적 해제 - Ray:
FailureConfig(fail_fast="raise")로 에러 즉시 노출, 디버깅 완료 후tune.with_resources()로 병렬 확장 이 네 가지를 조합하면 안정적이고 재현 가능한 분산 학습 환경을 구축할 수 있다.
참고 문헌
오류 및 메모리 문제
[1] TensorFlow double free or corruption 이슈 -- GitHub
[2] /dev/shm와 과할당 이해 -- Reddit
[3] /tmp vs /dev/shm 차이 -- Super User
공유 메모리 및 컨테이너
[4] Ray Memory Management -- 공식 문서
[5] PyTorch Multiprocessing Best Practices -- 공식 문서
[6] Docker에서 /dev/shm 크기 조정 -- Stack Overflow
Ray 리소스 관리
[7] Trainable.cleanup() -- Ray Docs
[8] Ray Tune OOM at every second trial -- Stack Overflow
[9] Ray Out-of-Memory Prevention -- Ray Docs
PyTorch 메모리 및 Autograd
[10] for-loop에서 loss 누적과 누수 -- PyTorch Forums
[11] PyTorch 메모리 누수 디버깅 -- PyTorch Forums
[12] detach, no_grad, requires_grad 차이 -- PyTorch Forums
[13] 실전 detach 가이드 -- Medium
[14] detach().clone() vs clone().detach() -- PyTorch Forums
Ray 디버깅 및 구성
[15] Ray Distributed Debugger -- Ray Docs
[16] ray.init() API -- Ray Docs (local_mode deprecated 표기)
추가 자료
- memory-corruption Q&A -- Stack Overflow
- Jupyter에서
/dev/shm확장 -- Jupyter Forum - Terminated trial memory hold -- Ray Discuss
torch_shm_manager수동 GC -- Stack Overflow- PyTorch 메모리 누수 개요 -- Medium
- Ray 메모리 디버깅 가이드 -- Ray Docs
- CUDA Semantics -- PyTorch Docs
tune.with_resources()API -- Ray Docs